package edu.berkeley.nlp.math;

import edu.berkeley.nlp.mapper.AsynchronousMapper;
import edu.berkeley.nlp.mapper.SimpleMapper;
import edu.berkeley.nlp.util.CallbackFunction;
import edu.berkeley.nlp.util.CollectionUtils;
import edu.berkeley.nlp.util.Logger;
//import edu.berkeley.nlp.util.optionparser.GlobalOptionParser;
import edu.berkeley.nlp.util.Option;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Random;

/**
 * User: aria42
 * Date: Mar 10, 2009
 */
/**
 * Minimizes an objective via stochastic (sub)gradient descent with an
 * adaptive, pass-level learning rate: after each full pass over the data the
 * step size {@code alpha} is multiplied by {@code upAlphaMult} if the summed
 * pass value improved, and by {@code downAlphaMult} otherwise. Optimization
 * stops after {@code numIters} passes or once {@code alpha} shrinks below
 * 1% of {@code initAlpha}. Optionally ({@code doAveraging}) the average of
 * all intermediate weight vectors is returned instead of the final iterate.
 *
 * Gradient updates may be applied concurrently by several mapper threads
 * (one {@link GradMapper} per item function, driven by
 * {@code AsynchronousMapper}); all access to the shared weight vector and
 * averaging state is guarded by {@code weightLock}.
 *
 * User: aria42
 * Date: Mar 10, 2009
 */
public class StochasticObjectiveOptimizer<I> {

  /** Fraction of {@code initAlpha} below which we declare convergence. */
  private static final double ALPHA_TOLERANCE_FRACTION = 1.0e-2;

  Collection<I> items;
  List<? extends ObjectiveItemDifferentiableFunction<I>> itemFns;
  Regularizer regularizer;            // may be null => no regularization
  double initAlpha = 0.5;             // starting step size
  double upAlphaMult = 1.1;           // step-size multiplier after an improving pass
  double downAlphaMult = 0.5;         // step-size multiplier after a non-improving pass
  Object weightLock = new Object();   // guards weights, sumWeightVector, numUpdates
  double[] weights;                   // current (shared) parameter vector
  double alpha;                       // current step size
  CallbackFunction iterDoneCallback;  // invoked after each pass, if set
  boolean printProgress = true;
  Random rand;

  @Option public int randSeed = 0;
  @Option public boolean doAveraging = false;  // return averaged weights instead of the final iterate
  @Option public boolean shuffleData = false;  // reshuffle items before each pass

  double[] sumWeightVector;  // running sum of post-update weight vectors, for averaging
  int numUpdates;            // number of SGD updates applied so far

  /** Convenience constructor with progress printing enabled. */
  public StochasticObjectiveOptimizer(double initAlpha, double upAlphaMult, double downAlphaMult)
  {
    this(initAlpha,upAlphaMult,downAlphaMult,true);
  }

  /**
   * @param initAlpha     initial step size
   * @param upAlphaMult   factor applied to alpha after an improving pass
   * @param downAlphaMult factor applied to alpha after a non-improving pass
   * @param printProgress whether to log per-iteration progress
   */
  public StochasticObjectiveOptimizer(double initAlpha, double upAlphaMult, double downAlphaMult, boolean printProgress)
  {
    this.initAlpha = initAlpha;
    this.upAlphaMult = upAlphaMult;
    this.downAlphaMult = downAlphaMult;
    this.printProgress = printProgress;
//    GlobalOptionParser.fillOptions(this);
    // NOTE(review): rand is seeded here, before any external option filling
    // could change randSeed — confirm this ordering is intended if the
    // GlobalOptionParser call above is ever re-enabled.
    rand = new Random(randSeed);
  }

  /** Registers a callback invoked after every pass with (iter, weights, val, alpha). */
  public void setIterationCallback(CallbackFunction iterDoneCallback) {
    this.iterDoneCallback = iterDoneCallback;
  }

  /**
   * One SGD worker for a pass through the data: for each item it snapshots the
   * shared weights, computes the item's gradient (plus the regularizer's
   * contribution scaled by 1/|items|) outside the lock, then applies the
   * step back to the shared weights. Each mapper accumulates its own value
   * sum; mappers run concurrently on disjoint items.
   */
  class GradMapper implements SimpleMapper<I> {
    double val = 0.0;
    ObjectiveItemDifferentiableFunction<I> itemFn;

    GradMapper(ObjectiveItemDifferentiableFunction<I> itemFn) {
      this.itemFn = itemFn;
    }

    public void map(I elem) {
      // Snapshot the weights so the (expensive) gradient computation runs
      // outside the lock.
      double[] localWeights;
      synchronized (weightLock) {
        localWeights = DoubleArrays.clone(weights);
      }
      double[] localGrad = new double[dimension()];
      itemFn.setWeights(localWeights);
      val += itemFn.update(elem, localGrad);
      if (regularizer != null) {
        val += regularizer.update(localWeights, localGrad, 1.0 / items.size());
      }
      // NOTE(review): alpha is read here without holding weightLock; a
      // concurrent iteration-boundary change is benign for this use but is
      // technically a data race.
      synchronized (weightLock) {
        DoubleArrays.addInPlace(weights, localGrad, -alpha);
        DoubleArrays.addInPlace(sumWeightVector, weights);
        numUpdates++;
      }
    }
  }

  /**
   * Computes the objective value at the current (fixed) weights without
   * updating them.
   */
  class ValMapper implements SimpleMapper<I> {
    double val = 0.0;
    ObjectiveItemDifferentiableFunction<I> itemFn;

    ValMapper(ObjectiveItemDifferentiableFunction<I> itemFn) {
      this.itemFn = itemFn;
    }

    public void map(I elem) {
      val += itemFn.update(elem, null);
      // BUG FIX: guard against a null regularizer, matching GradMapper —
      // previously this threw an NPE whenever no regularizer was supplied.
      if (regularizer != null) {
        val += regularizer.val(weights, 1.0 / items.size());
      }
    }
  }

  /**
   * Runs one SGD pass over the data (optionally shuffled) and returns the sum
   * of per-item values observed during the pass. Values are computed at the
   * evolving weights, so the result only approximates the objective at the
   * pass-final weights.
   */
  private double doIter() {
    List<GradMapper> gradMappers = new ArrayList<GradMapper>();
    for (ObjectiveItemDifferentiableFunction<I> itemFn : itemFns) {
      gradMappers.add(new GradMapper(itemFn));
    }
    List<I> shuffledItems = shuffleData ? CollectionUtils.shuffle(items, rand) : new ArrayList<I>(items);
    AsynchronousMapper.doMapping(shuffledItems, gradMappers);

    double val = 0.0;
    for (GradMapper mapper : gradMappers) {
      val += mapper.val;
    }
    return val;
  }

  /**
   * Runs up to {@code numIters} SGD passes from {@code initWeights}.
   *
   * @param initWeights starting point (cloned; the caller's array is not modified)
   * @param numIters    maximum number of passes over the data
   * @param items       training items
   * @param itemFns     one differentiable item function per worker thread
   * @param regularizer optional regularizer (may be null)
   * @return the final weight vector, or the running average of all
   *         intermediate weight vectors when {@code doAveraging} is set
   */
  public double[] minimize(double[] initWeights,
                           int numIters,
                           Collection<I> items,
                           List<? extends ObjectiveItemDifferentiableFunction<I>> itemFns,
                           Regularizer regularizer)
  {
    this.items = items;
    this.itemFns = itemFns;
    this.numUpdates = 0;
    this.regularizer = regularizer;
    alpha = initAlpha;
    weights = DoubleArrays.clone(initWeights);
    sumWeightVector = DoubleArrays.constantArray(0.0, weights.length);
    // Hoisted so the logged tolerance can never disagree with the test below.
    double alphaTolerance = initAlpha * ALPHA_TOLERANCE_FRACTION;
    double lastVal = Double.POSITIVE_INFINITY;
    for (int iter = 0; iter < numIters; iter++) {
      double val = doIter();
      // Grow the step size on improvement, shrink it otherwise.
      double alphaMult = val < lastVal ? upAlphaMult : downAlphaMult;
      alpha *= alphaMult;
      lastVal = val;
      if (printProgress) {
        Logger.logs("[StochasticObjectiveOptimizer] Ended Iteration %d with value %.5f",iter+1,val);
        Logger.logs("[StochasticObjectiveOptimizer] New Alpha: %.5f (scaled by %.5f)",alpha,alphaMult);
      }
      if (iterDoneCallback != null) {
        iterDoneCallback.callback(iter,doAveraging ? avgWeightVector() : weights,val,alpha);
      }

      if (alpha < alphaTolerance) {
        Logger.logs("[StochasticObjectiveOptimizer] alpha %.5f below tolerance %.5f, saying converged", alpha, alphaTolerance);
        break;
      }
    }
    return doAveraging ? avgWeightVector() : weights;
  }

  /**
   * Average of all intermediate weight vectors seen so far. Falls back to the
   * current weights when no updates have occurred, instead of producing an
   * all-NaN array via a 0/0 division.
   */
  private double[] avgWeightVector() {
    if (numUpdates == 0) {
      // BUG FIX: previously scaled by 1.0/0, returning NaNs for numIters == 0.
      return DoubleArrays.clone(weights);
    }
    double[] avgWeights = DoubleArrays.clone(sumWeightVector);
    DoubleArrays.scale(avgWeights, 1.0 / numUpdates);
    return avgWeights;
  }

  /** Dimensionality of the weight vector, taken from the first item function. */
  public int dimension() {
    return itemFns.get(0).dimension();
  }
}
